# -*- coding: utf-8 -*-


from pathlib import Path
import numpy as np
from osgeo import gdal
import glob

# Evaluate every non-SWA checkpoint found in ./model: compare the predicted
# label tiles in that checkpoint's output directory against the same-named
# tiles in TEST_DIR and print the mean per-tile pixel agreement.

TILE_SIZE = 256  # width and height (pixels) of the window read from each image
TEST_DIR = "../data/suichang_round1_test"


def read_tile(path):
    """Read the top-left TILE_SIZE x TILE_SIZE window of a raster file.

    Args:
        path: File path handed to ``gdal.Open``.

    Returns:
        The array returned by ``ReadAsArray`` for that window.

    Raises:
        RuntimeError: If GDAL returns no dataset for ``path``.
    """
    dataset = gdal.Open(path)
    if dataset is None:
        # Without gdal.UseExceptions(), a failed open yields None; fail with a
        # message naming the file instead of an AttributeError on None.
        raise RuntimeError(f"GDAL could not open raster: {path}")
    return dataset.ReadAsArray(0, 0, TILE_SIZE, TILE_SIZE)


def tile_accuracy(prediction, reference):
    """Return the count of equal elements divided by TILE_SIZE squared.

    NOTE(review): this is a fraction in [0, 1] only when both arrays hold
    exactly TILE_SIZE * TILE_SIZE elements (single-band tiles) -- confirm.
    """
    return np.sum(prediction == reference) / TILE_SIZE / TILE_SIZE


def mean_accuracy(scores):
    """Return the arithmetic mean of ``scores``, or None when it is empty."""
    if not scores:
        return None
    return sum(scores) / len(scores)


for model_path in Path("./model").glob("*.pth"):
    if "swa" in model_path.name:
        continue

    # NOTE(review): no separator is inserted, so the directory is named
    # "prediction_result<stem>" -- confirm this matches the prediction script.
    output_dir = "../prediction_result" + model_path.stem
    print(output_dir)

    label_paths = glob.glob(f"{output_dir}/*.png")
    scores = [
        tile_accuracy(
            read_tile(label_path),
            read_tile(f"{TEST_DIR}/{Path(label_path).name}"),
        )
        for label_path in label_paths
    ]

    print(model_path)
    accuracy = mean_accuracy(scores)
    if accuracy is None:
        # Nothing to average: report it rather than dividing by zero.
        print(f"No prediction images found in {output_dir}; skipping.")
        continue
    print(f"正确率: {accuracy}")
